{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import random\n",
    "import torch\n",
    "import numpy as np\n",
    "import src.comp.DataloadernoConJP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "outputs": [],
   "source": [
    "class Encoder(nn.Module):\n",
    "    def __init__(self, input_dim, emb_dim, hid_dim, n_layers=2, dropout=0.8):\n",
    "        super().__init__()\n",
    "\n",
    "        self.hid_dim = hid_dim\n",
    "        self.n_layers = n_layers\n",
    "\n",
    "        # self.embedding = nn.Embedding(input_dim, emb_dim)\n",
    "        # self.line1=nn.Linear()\n",
    "\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout = dropout)\n",
    "\n",
    "\n",
    "\n",
    "    def forward(self, src):\n",
    "\n",
    "        #src = [src len, batch size]\n",
    "\n",
    "        # embedded = self.dropout(self.embedding(src))\n",
    "        embedded = self.dropout(src)\n",
    "        #embedded = [src len, batch size, emb dim]\n",
    "\n",
    "        outputs, (hidden, cell) = self.rnn(embedded)\n",
    "\n",
    "        #outputs = [src len, batch size, hid dim * n directions]\n",
    "        #hidden = [n layers * n directions, batch size, hid dim]\n",
    "        #cell = [n layers * n directions, batch size, hid dim]\n",
    "\n",
    "        #outputs are always from the top hidden layer\n",
    "\n",
    "        return outputs,hidden, cell"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "outputs": [],
   "source": [
    "class Decoder(nn.Module):\n",
    "    def __init__(self, output_dim, emb_dim, hid_dim, n_layers=2, dropout=0.8):\n",
    "        super().__init__()\n",
    "\n",
    "        self.output_dim = output_dim\n",
    "        self.hid_dim = hid_dim\n",
    "        self.n_layers = n_layers\n",
    "\n",
    "        # self.embedding = nn.Embedding(output_dim, emb_dim)\n",
    "\n",
    "        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout = dropout)\n",
    "\n",
    "        self.fc_out = nn.Linear(hid_dim, output_dim)\n",
    "\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "    def forward(self, input):\n",
    "\n",
    "        #input = [batch size]\n",
    "        #hidden = [n layers * n directions, batch size, hid dim]\n",
    "        #cell = [n layers * n directions, batch size, hid dim]\n",
    "\n",
    "        #n directions in the decoder will both always be 1, therefore:\n",
    "        #hidden = [n layers, batch size, hid dim]\n",
    "        #context = [n layers, batch size, hid dim]\n",
    "\n",
    "        # input = input.unsqueeze(0)\n",
    "\n",
    "        #input = [1, batch size]\n",
    "\n",
    "        embedded = self.dropout(input)\n",
    "\n",
    "        #embedded = [1, batch size, emb dim]\n",
    "\n",
    "        output, (hidden, cell) = self.rnn(embedded)\n",
    "\n",
    "        #output = [seq len, batch size, hid dim * n directions]\n",
    "        #hidden = [n layers * n directions, batch size, hid dim]\n",
    "        #cell = [n layers * n directions, batch size, hid dim]\n",
    "\n",
    "        #seq len and n directions will always be 1 in the decoder, therefore:\n",
    "        #output = [1, batch size, hid dim]\n",
    "        #hidden = [n layers, batch size, hid dim]\n",
    "        #cell = [n layers, batch size, hid dim]\n",
    "\n",
    "        # prediction = self.fc_out(output.squeeze(0))\n",
    "        prediction = output\n",
    "        #prediction = [batch size, output dim]\n",
    "\n",
    "        return prediction"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "outputs": [],
   "source": [
    "class Seq2Seq(nn.Module):\n",
    "    def __init__(self, encoder, decoder):\n",
    "        super().__init__()\n",
    "\n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "\n",
    "\n",
    "        assert encoder.hid_dim == decoder.hid_dim, \\\n",
    "            \"Hidden dimensions of encoder and decoder must be equal!\"\n",
    "        assert encoder.n_layers == decoder.n_layers, \\\n",
    "            \"Encoder and decoder must have equal number of layers!\"\n",
    "\n",
    "    def forward(self, src, trg, teacher_forcing_ratio = 0.5):\n",
    "\n",
    "        #src = [src len, batch size]\n",
    "        #trg = [trg len, batch size]\n",
    "        #teacher_forcing_ratio is probability to use teacher forcing\n",
    "        #e.g. if teacher_forcing_ratio is 0.75 we use ground-truth inputs 75% of the time\n",
    "\n",
    "        batch_size = trg.shape[1]\n",
    "        trg_len = trg.shape[0]\n",
    "        trg_vocab_size = self.decoder.output_dim\n",
    "\n",
    "        #tensor to store decoder outputs\n",
    "        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size)\n",
    "\n",
    "        #last hidden state of the encoder is used as the initial hidden state of the decoder\n",
    "        hidden, cell = self.encoder(src)\n",
    "\n",
    "        #first input to the decoder is the <sos> tokens\n",
    "        input = trg[0,:]\n",
    "\n",
    "        for t in range(1, trg_len):\n",
    "\n",
    "            #insert input token embedding, previous hidden and previous cell states\n",
    "            #receive output tensor (predictions) and new hidden and cell states\n",
    "            output, hidden, cell = self.decoder(input, hidden, cell)\n",
    "\n",
    "            #place predictions in a tensor holding predictions for each token\n",
    "            outputs[t] = output\n",
    "\n",
    "            #decide if we are going to use teacher forcing or not\n",
    "            teacher_force = random.random() < teacher_forcing_ratio\n",
    "\n",
    "            #get the highest predicted token from our predictions\n",
    "            top1 = output.argmax(1)\n",
    "\n",
    "            #if teacher forcing, use actual next token as next input\n",
    "            #if not, use predicted token\n",
    "            input = trg[t] if teacher_force else top1\n",
    "\n",
    "        return outputs"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "outputs": [],
   "source": [
    "INPUT_DIM = 20\n",
    "OUTPUT_DIM = 20\n",
    "ENC_EMB_DIM = 64\n",
    "DEC_EMB_DIM = 64\n",
    "HID_DIM = 32\n",
    "N_LAYERS = 2\n",
    "ENC_DROPOUT = 0.5\n",
    "DEC_DROPOUT = 0.5\n",
    "\n",
    "enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT)\n",
    "dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT)\n",
    "\n",
    "model = Seq2Seq(enc, dec)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "outputs": [],
   "source": [
    "optimizer = torch.optim.Adam(model.parameters())"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "outputs": [],
   "source": [
    "def train(model, iterator, optimizer, criterion, clip):\n",
    "\n",
    "    model.train()\n",
    "\n",
    "    epoch_loss = 0\n",
    "\n",
    "    for i, batch in enumerate(iterator):\n",
    "\n",
    "        src = batch.src\n",
    "        trg = batch.trg\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        output = model(src, trg)\n",
    "\n",
    "        #trg = [trg len, batch size]\n",
    "        #output = [trg len, batch size, output dim]\n",
    "\n",
    "        output_dim = output.shape[-1]\n",
    "\n",
    "        output = output[1:].view(-1, output_dim)\n",
    "        trg = trg[1:].view(-1)\n",
    "\n",
    "        #trg = [(trg len - 1) * batch size]\n",
    "        #output = [(trg len - 1) * batch size, output dim]\n",
    "\n",
    "        loss = criterion(output, trg)\n",
    "\n",
    "        loss.backward()\n",
    "\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)\n",
    "\n",
    "        optimizer.step()\n",
    "\n",
    "        epoch_loss += loss.item()\n",
    "\n",
    "    return epoch_loss / len(iterator)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "outputs": [],
   "source": [
    "def evaluate(model, iterator, criterion):\n",
    "\n",
    "    model.eval()\n",
    "\n",
    "    epoch_loss = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "\n",
    "        for i, batch in enumerate(iterator):\n",
    "\n",
    "            src = batch.src\n",
    "            trg = batch.trg\n",
    "\n",
    "            output = model(src, trg, 0) #turn off teacher forcing\n",
    "\n",
    "            #trg = [trg len, batch size]\n",
    "            #output = [trg len, batch size, output dim]\n",
    "\n",
    "            output_dim = output.shape[-1]\n",
    "\n",
    "            output = output[1:].view(-1, output_dim)\n",
    "            trg = trg[1:].view(-1)\n",
    "\n",
    "            #trg = [(trg len - 1) * batch size]\n",
    "            #output = [(trg len - 1) * batch size, output dim]\n",
    "\n",
    "            loss = criterion(output, trg)\n",
    "\n",
    "            epoch_loss += loss.item()\n",
    "\n",
    "    return epoch_loss / len(iterator)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "outputs": [],
   "source": [
    "data_params: dict = {'batch_size': 300,\n",
    "                                  'shuffle': True,\n",
    "                                  'num_workers': 6}\n",
    "from comp.DataloadernoConJP import DataloadernoConJP as Dataloader"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load lyrics_seq\n",
      "load discrete_attr_seq\n"
     ]
    }
   ],
   "source": [
    "training_set = Dataloader('2022-01-15_18_30_13_embeddings10_vector.pt',\n",
    "                                      '2022-01-15_18_30_13_vocabulary10_lookup.json',\n",
    "                                      20)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "outputs": [],
   "source": [
    "from torch.utils import data\n",
    "train_data_iterator = data.DataLoader(training_set, **data_params)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "outputs": [],
   "source": [
    "lyrics_seq = None\n",
    "                # cont_val_seq = data[1].to(self.device)\n",
    "discrete_val_seq = None\n",
    "\n",
    "for num_steps_D, data in enumerate(train_data_iterator):\n",
    "    lyrics_seq = data[0]\n",
    "    discrete_val_seq = data[1]\n",
    "    break"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "outputs": [],
   "source": [
    "en=Encoder(0,20,100,n_layers=2)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "outputs": [],
   "source": [
    "outputs,hidden,cell=en(lyrics_seq)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "outputs": [],
   "source": [
    "de=Decoder(3,100,3)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "outputs": [],
   "source": [
    "out_2=de(outputs)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}